import torch
import torch.nn as nn
import torch.nn.functional as F

import math
import numpy as np

class GroupLinearLayer(nn.Module):
    """Bank of independent linear maps, one per block.

    Block ``i`` owns its own ``(din, dout)`` weight matrix ``w[i]`` and is
    applied only to slice ``x[:, i]`` of the input.

    Args:
        num_blocks: number of independent linear maps.
        din: input feature size of every block.
        dout: output feature size of every block.
        bias: when truthy, the per-block bias ``b`` is added in ``forward``.
            The parameter ``b`` is registered either way.
    """

    def __init__(self, num_blocks, din, dout, bias=True):
        super().__init__()

        self.bias = bias
        self.w = nn.Parameter(torch.Tensor(num_blocks, din, dout))
        self.b = nn.Parameter(torch.Tensor(1, num_blocks, dout))

        # Symmetric uniform init with bound sqrt(6) / sqrt(din + dout);
        # the bias starts at zero.
        bound = math.sqrt(6.0) / math.sqrt(din + dout)
        nn.init.zeros_(self.b)
        nn.init.uniform_(self.w, -bound, bound)

    def forward(self, x):
        """Apply each block's weight to its slice of ``x``.

        Args:
            x: tensor of shape ``(bsz, num_blocks, din)``.

        Returns:
            Tensor of shape ``(bsz, num_blocks, dout)``.
        """
        # bmm batches over its leading axis, so put the block axis first.
        per_block = x.transpose(0, 1)
        projected = torch.bmm(per_block, self.w).transpose(0, 1)

        if not self.bias:
            return projected
        return projected + self.b

class Compositional_Transformer(nn.Module):
    """Attention layer whose search heads and retrieval heads are decoupled.

    ``search`` query/key heads each produce an attention map over the ``n``
    input elements (an element never attends to itself). Every attention map
    is applied to each of the ``retrieval`` value heads, and the externally
    supplied ``label`` weights decide how the retrievals are mixed for each
    search head.

    Args:
        dim: feature size of the input and of the output.
        search_dim: total query/key width; must be divisible by ``search``.
        value_dim: total value width; must be divisible by ``retrieval``.
        search: number of query/key (search) heads.
        retrieval: number of value (retrieval) heads.
        nonlinear: if truthy, the output projection is a two-layer ReLU MLP
            instead of a single linear layer.
        gumbel: if truthy, attention weights come from ``F.gumbel_softmax``
            instead of ``F.softmax``.
        bias: whether the linear layers carry a bias term.
    """

    def __init__(self, dim, search_dim, value_dim, search, retrieval, nonlinear, gumbel, bias):
        super(Compositional_Transformer, self).__init__()

        self.dim = dim
        self.search_dim = search_dim
        self.value_dim = value_dim
        self.head_dim = search_dim // search
        self.head_v_dim = value_dim // retrieval
        self.nonlinear = nonlinear
        self.search = search
        self.retrieval = retrieval
        self.scaling = self.head_dim ** -0.5
        self.gumbel = gumbel

        self.query_net = nn.Linear(dim, search_dim, bias=bias)
        self.key_net = nn.Linear(dim, search_dim, bias=bias)
        self.value_net = nn.Linear(dim, value_dim, bias=bias)

        assert self.head_dim * search == search_dim, \
            "search_dim must be divisible by search"
        assert self.head_v_dim * retrieval == value_dim, \
            "value_dim must be divisible by retrieval"

        # After the retrieval axis is summed out in forward(), each element
        # carries `search` vectors of size head_v_dim.
        if self.nonlinear:
            self.out_proj = nn.Sequential(
                nn.Linear(self.search * self.head_v_dim, dim, bias=bias),
                nn.ReLU(),
                nn.Linear(dim, dim, bias=bias)
            )
        else:
            self.out_proj = nn.Linear(self.search * self.head_v_dim, dim, bias=bias)

    def forward(self, x, label):
        """Run masked attention and mix the retrievals with ``label``.

        Args:
            x: tensor of shape ``(bsz, n, dim)``.
            label: mixing weights; after a trailing unit axis is appended it
                is multiplied with a ``(bsz, n, search, retrieval,
                head_v_dim)`` tensor, so it must broadcast against
                ``(bsz, n, search, retrieval)``.

        Returns:
            Tuple ``(out, score, v_score)``:
            ``out`` is ``(bsz, n, dim)``; ``score`` holds the attention
            weights with shape ``(bsz, search, 1, n, n)``; ``v_score`` is
            ``label`` with a trailing unit axis.

        Note:
            With ``n == 1`` the only candidate is the element itself, which
            is masked out, so the softmax runs over ``-inf`` alone and the
            weights are NaN.
        """
        bsz, n, _ = x.shape

        q = self.query_net(x).view(bsz, n, self.search, self.head_dim) * self.scaling
        k = self.key_net(x).view(bsz, n, self.search, self.head_dim)
        v = self.value_net(x).view(bsz, n, self.retrieval, self.head_v_dim)

        q = q.transpose(2, 1).contiguous()      # (bsz, search, n, head_dim)
        k = k.permute(0, 2, 3, 1).contiguous()  # (bsz, search, head_dim, n)
        v = v.transpose(2, 1).contiguous().unsqueeze(1)  # (bsz, 1, retrieval, n, head_v_dim)

        score = torch.matmul(q, k)  # (bsz, search, n, n)

        # Forbid self-attention. One (n, n) boolean mask is broadcast over
        # the batch and head axes; it does not index into `score`, so an
        # empty batch is handled, and no (bsz, search, n, n) copy is built.
        self_mask = torch.eye(n, dtype=torch.bool, device=score.device)
        score.masked_fill_(self_mask, float('-inf'))

        if self.gumbel:
            score = F.gumbel_softmax(score, dim=-1).unsqueeze(2)
        else:
            score = F.softmax(score, dim=-1).unsqueeze(2)  # (bsz, search, 1, n, n)

        # Broadcasting pairs every search head with every retrieval head:
        # (bsz, search, retrieval, n, head_v_dim) -> (bsz, n, search, retrieval, head_v_dim)
        out = torch.matmul(score, v).permute(0, 3, 1, 2, 4).reshape(
            bsz, n, self.search, self.retrieval, self.head_v_dim)

        # Weight each (search, retrieval) pair by the label and sum out the
        # retrieval axis.
        v_score = label.unsqueeze(-1)
        out = (v_score * out).sum(dim=3).reshape(bsz, n, self.search * self.head_v_dim)

        return self.out_proj(out), score, v_score

class Model(nn.Module):
    """Encoder -> Compositional_Transformer -> decoder pipeline.

    The trailing ``v_s * v_p`` features of every input element are treated
    as fixed (detached) retrieval weights and handed to the attention layer
    as its ``label`` argument; the full feature vector is still encoded.

    Args:
        in_dim: feature size of each input element.
        dim: hidden width of the encoder, attention layer and decoder.
        search_dim: total query/key width of the attention layer.
        value_dim: total value width of the attention layer.
        search: number of search heads.
        retrieve: number of retrieval heads.
        nonlinear: use an MLP output projection in the attention layer.
        bias: whether the attention layer's linear maps carry a bias.
        gumbel: use gumbel-softmax attention weights.
        v_s: first label axis size.
        v_p: second label axis size.
    """

    def __init__(self, in_dim=3, dim=64, search_dim=64, value_dim=64,
                 search=4, retrieve=4, nonlinear=False,
                 bias=True, gumbel=False, v_s=2, v_p=2):
        super().__init__()

        self.v_s, self.v_p = v_s, v_p

        # Two-layer ReLU MLP lifting raw features to the hidden width.
        self.encoder = nn.Sequential(
            nn.Linear(in_dim, dim), nn.ReLU(), nn.Linear(dim, dim),
        )
        # Two-layer ReLU MLP producing one scalar per element.
        self.decoder = nn.Sequential(
            nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, 1),
        )

        self.model = Compositional_Transformer(
            dim, search_dim, value_dim,
            search, retrieve, nonlinear, gumbel, bias,
        )

    def forward(self, x):
        """Predict one scalar per element.

        Args:
            x: tensor of shape ``(bsz, n, in_dim)``. Its last ``v_s * v_p``
                features are reshaped to ``(bsz, n, v_s, v_p)``, so the
                feature axis must hold at least that many entries.

        Returns:
            Tuple ``(prediction, score, f_score)``: the decoder output plus
            the second and third values returned by the attention layer.
        """
        batch, count, _ = x.shape

        # Labels are read off the raw input and cut from the autograd graph.
        n_label = self.v_s * self.v_p
        label = x[:, :, -n_label:].view(batch, count, self.v_s, self.v_p).detach()

        hidden = self.encoder(x)
        hidden, score, f_score = self.model(hidden, label)
        prediction = self.decoder(hidden)

        return prediction, score, f_score